{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.OmniScaleCNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OmniScaleCNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">This is an unofficial PyTorch implementation created by Ignacio Oguiza - oguiza@timeseriesAI.co"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "#This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@timeseriesAI.co based on:\n",
    "# Rußwurm, M., & Körner, M. (2019). Self-attention for raw optical satellite time series classification. arXiv preprint arXiv:1910.10536.\n",
    "# Official implementation: https://github.com/dl4sits/BreizhCrops/blob/master/breizhcrops/models/OmniScaleCNN.py\n",
    "\n",
    "class SampaddingConv1D_BN(Module):\n",
    "    \"\"\"Zero-pad, convolve and batch-normalise a 1D signal.\n",
    "\n",
    "    The input receives `int((kernel_size - 1) / 2)` zeros on the left and\n",
    "    `int(kernel_size / 2)` zeros on the right before a unit-stride convolution,\n",
    "    so for positive kernel sizes the output keeps the length of the input.\n",
    "\n",
    "    Args:\n",
    "        in_channels: number of channels of the incoming tensor.\n",
    "        out_channels: number of filters produced by the convolution.\n",
    "        kernel_size: width of the convolution kernel.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels, out_channels, kernel_size):\n",
    "        # NOTE(review): there is no explicit super().__init__() call; presumably the\n",
    "        # `Module` base coming from the star imports handles it -- confirm.\n",
    "        pad_left = int((kernel_size - 1) / 2)\n",
    "        pad_right = int(kernel_size / 2)\n",
    "        self.padding = nn.ConstantPad1d((pad_left, pad_right), 0)\n",
    "        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size)\n",
    "        self.bn = nn.BatchNorm1d(out_channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply padding, then the convolution, then batch norm, and return the result.\"\"\"\n",
    "        return self.bn(self.conv1d(self.padding(x)))\n",
    "\n",
    "\n",
    "class build_layer_with_layer_parameter(Module):\n",
    "    \"\"\"Multi-kernel block: parallel `SampaddingConv1D_BN` branches whose outputs are\n",
    "    concatenated along dim 1 and passed through a ReLU.\n",
    "    \"\"\"\n",
    "    def __init__(self, layer_parameters):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            layer_parameters: iterable with one entry per branch; each entry is indexed as\n",
    "                (in_channels, out_channels, kernel_size), e.g.\n",
    "                [(3, 56, 1), (3, 56, 2), (3, 56, 3)].\n",
    "        \"\"\"\n",
    "        # One padded conv + batch-norm branch per (in_channels, out_channels, kernel_size) entry.\n",
    "        self.conv_list = nn.ModuleList(\n",
    "            [SampaddingConv1D_BN(spec[0], spec[1], spec[2]) for spec in layer_parameters]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run `x` through every branch, concatenate along dim 1 and apply ReLU.\"\"\"\n",
    "        branch_outputs = [branch(x) for branch in self.conv_list]\n",
    "        return F.relu(torch.cat(branch_outputs, 1))\n",
    "\n",
    "\n",
    "class OmniScaleCNN(Module):\n",
    "    \"\"\"Stack of multi-kernel convolutional blocks, global average pooling and a linear head.\n",
    "\n",
    "    Args:\n",
    "        c_in: number of input channels.\n",
    "        c_out: number of output features of the final linear layer.\n",
    "        seq_len: length of the input sequences. Kernel sizes are drawn from the range\n",
    "            1 .. max(1, seq_len // 4).\n",
    "        layers: one parameter budget per multi-kernel block, forwarded to\n",
    "            `generate_layer_parameter_list`. The default list is only read, never mutated.\n",
    "        few_shot: when True, `forward` returns the pooled features and skips the linear head.\n",
    "    \"\"\"\n",
    "    def __init__(self, c_in, c_out, seq_len, layers=[8 * 128, 5 * 128 * 256 + 2 * 256 * 128], few_shot=False):\n",
    "        # Clamp to 1: for seq_len < 4 the integer division gives 0, which used to produce an\n",
    "        # empty kernel-size list and a ZeroDivisionError while deriving the channel counts.\n",
    "        receptive_field_shape = max(1, seq_len // 4)\n",
    "        layer_parameter_list = generate_layer_parameter_list(1, receptive_field_shape, layers, in_channel=c_in)\n",
    "        self.few_shot = few_shot\n",
    "        self.layer_parameter_list = layer_parameter_list\n",
    "        # Kept as a plain list attribute; the blocks are registered as submodules through self.net.\n",
    "        self.layer_list = [build_layer_with_layer_parameter(params) for params in layer_parameter_list]\n",
    "        self.net = nn.Sequential(*self.layer_list)\n",
    "        self.gap = GAP1d(1)\n",
    "        # The head consumes the concatenated outputs of every branch of the last block.\n",
    "        out_put_channel_number = sum(params[1] for params in layer_parameter_list[-1])\n",
    "        self.hidden = nn.Linear(out_put_channel_number, c_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the head output, or the pooled features when `few_shot` is True.\"\"\"\n",
    "        x = self.gap(self.net(x))\n",
    "        if not self.few_shot: x = self.hidden(x)\n",
    "        return x\n",
    "\n",
    "def get_Prime_number_in_a_range(start, end):\n",
    "    Prime_list = []\n",
    "    for val in range(start, end + 1):\n",
    "        prime_or_not = True\n",
    "        for n in range(2, val):\n",
    "            if (val % n) == 0:\n",
    "                prime_or_not = False\n",
    "                break\n",
    "        if prime_or_not:\n",
    "            Prime_list.append(val)\n",
    "    return Prime_list\n",
    "\n",
    "\n",
    "def get_out_channel_number(paramenter_layer, in_channel, prime_list):\n",
    "    out_channel_expect = max(1, int(paramenter_layer / (in_channel * sum(prime_list))))\n",
    "    return out_channel_expect\n",
    "\n",
    "\n",
    "def generate_layer_parameter_list(start, end, layers, in_channel=1):\n",
    "    \"\"\"Build the (in_channels, out_channels, kernel_size) tuples of every block.\n",
    "\n",
    "    Args:\n",
    "        start, end: bounds passed to `get_Prime_number_in_a_range` to obtain the kernel sizes.\n",
    "        layers: one parameter budget per block; `layers[0]` also sizes the final block.\n",
    "        in_channel: number of channels fed to the first block.\n",
    "\n",
    "    Returns:\n",
    "        A list with one list of tuples per entry of `layers`, followed by a final list\n",
    "        holding two tuples with kernel sizes 1 and 2.\n",
    "    \"\"\"\n",
    "    kernel_sizes = get_Prime_number_in_a_range(start, end)\n",
    "    n_kernels = len(kernel_sizes)\n",
    "\n",
    "    layer_parameter_list = []\n",
    "    for budget in layers:\n",
    "        out_channel = get_out_channel_number(budget, in_channel, kernel_sizes)\n",
    "        layer_parameter_list.append([(in_channel, out_channel, k) for k in kernel_sizes])\n",
    "        # The next block receives the outputs of all branches, hence the multiplication.\n",
    "        in_channel = n_kernels * out_channel\n",
    "\n",
    "    # Final block: its width is derived from the first budget with an input-channel count of 1.\n",
    "    first_out_channel = n_kernels * get_out_channel_number(layers[0], 1, kernel_sizes)\n",
    "    layer_parameter_list.append([(in_channel, first_out_channel, k) for k in (1, 2)])\n",
    "    return layer_parameter_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OmniScaleCNN(\n",
       "  (net): Sequential(\n",
       "    (0): build_layer_with_layer_parameter(\n",
       "      (conv_list): ModuleList(\n",
       "        (0): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(0, 0), value=0)\n",
       "          (conv1d): Conv1d(3, 56, kernel_size=(1,), stride=(1,))\n",
       "          (bn): BatchNorm1d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (1): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(0, 1), value=0)\n",
       "          (conv1d): Conv1d(3, 56, kernel_size=(2,), stride=(1,))\n",
       "          (bn): BatchNorm1d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (2): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(1, 1), value=0)\n",
       "          (conv1d): Conv1d(3, 56, kernel_size=(3,), stride=(1,))\n",
       "          (bn): BatchNorm1d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): build_layer_with_layer_parameter(\n",
       "      (conv_list): ModuleList(\n",
       "        (0): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(0, 0), value=0)\n",
       "          (conv1d): Conv1d(168, 227, kernel_size=(1,), stride=(1,))\n",
       "          (bn): BatchNorm1d(227, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (1): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(0, 1), value=0)\n",
       "          (conv1d): Conv1d(168, 227, kernel_size=(2,), stride=(1,))\n",
       "          (bn): BatchNorm1d(227, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (2): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(1, 1), value=0)\n",
       "          (conv1d): Conv1d(168, 227, kernel_size=(3,), stride=(1,))\n",
       "          (bn): BatchNorm1d(227, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (2): build_layer_with_layer_parameter(\n",
       "      (conv_list): ModuleList(\n",
       "        (0): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(0, 0), value=0)\n",
       "          (conv1d): Conv1d(681, 510, kernel_size=(1,), stride=(1,))\n",
       "          (bn): BatchNorm1d(510, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (1): SampaddingConv1D_BN(\n",
       "          (padding): ConstantPad1d(padding=(0, 1), value=0)\n",
       "          (conv1d): Conv1d(681, 510, kernel_size=(2,), stride=(1,))\n",
       "          (bn): BatchNorm1d(510, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (gap): GAP1d(\n",
       "    (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "    (flatten): Flatten(full=False)\n",
       "  )\n",
       "  (hidden): Linear(in_features=1020, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: a random batch of shape (bs, c_in, seq_len) must give an output of shape (bs, c_out).\n",
    "bs = 16\n",
    "c_in = 3\n",
    "seq_len = 12\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, c_in, seq_len)\n",
    "# NOTE(review): create_model and test_eq are not defined in this notebook; presumably they come from the star imports -- confirm.\n",
    "m = create_model(OmniScaleCNN, c_in, c_out, seq_len)\n",
    "test_eq(OmniScaleCNN(c_in, c_out, seq_len)(xb).shape, [bs, c_out])\n",
    "# Leave the model as the last expression so its structure is displayed as the cell output.\n",
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/109_models.OmniScaleCNN.ipynb saved at 2022-11-09 13:10:17\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 09/11/22 13:10:20 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Look up this notebook's name and run tsai's notebook-to-script conversion on it.\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
